(* ========================================================================= *)
(* NORMALIZING FORMULAS                                                      *)
(* Copyright (c) 2001 Joe Leslie-Hurd, distributed under the BSD License     *)
(* ========================================================================= *)

structure Normalize :> Normalize =
struct

open Useful;

(* ------------------------------------------------------------------------- *)
(* Constants.                                                                *)
(* ------------------------------------------------------------------------- *)

(* Common suffix shared by the names this module generates. *)
val prefix = "FOFtoCNF";

(* Leading part of generated Skolem function names. *)
val skolemPrefix = String.concat ["skolem", prefix];

(* Leading part of definition names (its use is not visible in this part *)
(* of the file).                                                         *)
val definitionPrefix = String.concat ["definition", prefix];

(* ------------------------------------------------------------------------- *)
(* Storing huge real numbers as their log.                                   *)
(* ------------------------------------------------------------------------- *)

(* A non-negative real represented by its natural logarithm. Any negative *)
(* stored value stands for the number zero.                               *)
datatype logReal = LogReal of real;

(* Order two logReals by comparing their stored logarithms. *)
fun compareLogReal (LogReal a, LogReal b) = Real.compare (a,b);

(* The number zero (canonical representative). *)
val zeroLogReal = LogReal ~1.0;

(* The number one, whose logarithm is 0. *)
val oneLogReal = LogReal 0.0;

local
  (* Negative logs encode zero. *)
  fun negative r = r < 0.0;

  (* Log of the sum of two numbers given by their logs; assumes hi >= lo >= 0. *)
  fun logSum hi lo = hi + Math.ln (1.0 + Math.exp (lo - hi));
in
  fun isZeroLogReal (LogReal r) = negative r;

  (* Product: zero if either factor is zero, otherwise add the logs. *)
  fun multiplyLogReal (LogReal a) (LogReal b) =
      if negative a then zeroLogReal
      else if negative b then zeroLogReal
      else LogReal (a + b);

  (* Sum: zero is the identity; otherwise combine with the larger log first. *)
  fun addLogReal (x as LogReal a) (y as LogReal b) =
      if negative a then y
      else if negative b then x
      else LogReal (if a < b then logSum b a else logSum a b);

  (* True when the first number is zero, or when both are non-zero and the *)
  (* first log is below the second log plus logDelta.                      *)
  fun withinRelativeLogReal logDelta (LogReal a) (LogReal b) =
      if negative a then true
      else if negative b then false
      else a < b + logDelta;
end;

(* Render the stored logarithm (not the number it stands for). *)
fun toStringLogReal (LogReal r) = Real.toString r;

(* ------------------------------------------------------------------------- *)
(* Counting the clauses that would be generated by conjunctive normal form.  *)
(* ------------------------------------------------------------------------- *)

(* Tolerance, in log space, that countLeqish passes to                    *)
(* withinRelativeLogReal.                                                 *)
val countLogDelta = 0.01;

(* A pair of clause counts stored as logReals. countNegate swaps the two *)
(* components, so `negative` is the count for the negated formula.       *)
datatype count = Count of {positive : logReal, negative : logReal};

(* Order two counts by their positive components; the negative components *)
(* are ignored.                                                           *)
fun countCompare (Count {positive = lhs, ...}, Count {positive = rhs, ...}) =
    compareLogReal (lhs,rhs);

(* Swap the positive and negative components of a count. *)
fun countNegate (Count parts) =
    Count {positive = #negative parts, negative = #positive parts};

(* Approximate "less than or equal" on the positive components of two *)
(* counts, with tolerance countLogDelta.                              *)
fun countLeqish (Count {positive = lhs, ...}) (Count {positive = rhs, ...}) =
    withinRelativeLogReal countLogDelta lhs rhs;

(*MetisDebug
fun countEqualish count1 count2 =
    countLeqish count1 count2 andalso
    countLeqish count2 count1;

fun countEquivish count1 count2 =
    countEqualish count1 count2 andalso
    countEqualish (countNegate count1) (countNegate count2);
*)

(* Count of the constant True: zero positive, one negative. *)
val countTrue = Count {positive = zeroLogReal, negative = oneLogReal};

(* Count of the constant False: one positive, zero negative. *)
val countFalse = Count {positive = oneLogReal, negative = zeroLogReal};

(* Count of a literal: one in either polarity. *)
val countLiteral = Count {positive = oneLogReal, negative = oneLogReal};

(* Count of a conjunction: positive counts add, negative counts multiply. *)
fun countAnd2 (Count {positive = posL, negative = negL},
               Count {positive = posR, negative = negR}) =
    Count {positive = addLogReal posL posR,
           negative = multiplyLogReal negL negR};

(* Count of a disjunction: positive counts multiply, negative counts add. *)
fun countOr2 (Count {positive = posL, negative = negL},
              Count {positive = posR, negative = negR}) =
    Count {positive = multiplyLogReal posL posR,
           negative = addLogReal negL negR};

(* Whether countXor2 is associative or not is an open question. *)

(* Count of an exclusive-or of two formulas:                   *)
(*   positive = posL * posR + negL * negR                      *)
(*   negative = posL * negR + negL * posR                      *)
fun countXor2 (Count {positive = posL, negative = negL},
               Count {positive = posR, negative = negR}) =
    let
      fun times x y = multiplyLogReal x y
      fun plus x y = addLogReal x y
    in
      Count {positive = plus (times posL posR) (times negL negR),
             negative = plus (times posL negR) (times negL posR)}
    end;

(* Count of a definition: the xor count of a literal with the body. *)
val countDefinition = fn bodyCount => countXor2 (countLiteral,bodyCount);

(* Render a count as "(+P,-N)" where P and N are the stored logarithms. *)
fun countToString (Count {positive, negative}) =
    String.concat
      ["(+", toStringLogReal positive, ",-", toStringLogReal negative, ")"];

(* Pretty-printer for counts: countToString followed by Print.ppString. *)
val ppCount = Print.ppMap countToString Print.ppString;

(* ------------------------------------------------------------------------- *)
(* A type of normalized formula.                                             *)
(* ------------------------------------------------------------------------- *)

(* In every non-constant constructor the leading NameSet.set is the set  *)
(* that freeVars returns. The count component of And/Or/Xor/quantifiers  *)
(* is the cached value read back by the function count. The bool of Xor  *)
(* is its polarity, and the second NameSet.set of Exists/Forall holds    *)
(* the variables bound by that quantifier.                               *)
datatype formula =
    True
  | False
  | Literal of NameSet.set * Literal.literal
  | And of NameSet.set * count * formula Set.set
  | Or of NameSet.set * count * formula Set.set
  | Xor of NameSet.set * count * bool * formula Set.set
  | Exists of NameSet.set * count * NameSet.set * formula
  | Forall of NameSet.set * count * NameSet.set * formula;

(* Total order on normalized formulas. Formulas built with different     *)
(* constructors are ordered                                              *)
(*   True < False < Literal < And < Or < Xor < Exists < Forall;          *)
(* formulas built with the same constructor are ordered by their         *)
(* contents, ignoring the cached free-variable sets and counts.          *)
fun compare (pair as (lhs,rhs)) =
    if Portable.pointerEqual pair then EQUAL
    else
      let
        (* Position of the head constructor in the order above. *)
        fun rank True = 0
          | rank False = 1
          | rank (Literal _) = 2
          | rank (And _) = 3
          | rank (Or _) = 4
          | rank (Xor _) = 5
          | rank (Exists _) = 6
          | rank (Forall _) = 7
      in
        case pair of
          (Literal (_,l1), Literal (_,l2)) => Literal.compare (l1,l2)
        | (And (_,_,s1), And (_,_,s2)) => Set.compare (s1,s2)
        | (Or (_,_,s1), Or (_,_,s2)) => Set.compare (s1,s2)
        | (Xor (_,_,p1,s1), Xor (_,_,p2,s2)) =>
          (case boolCompare (p1,p2) of
             EQUAL => Set.compare (s1,s2)
           | ord => ord)
        | (Exists (_,_,n1,f1), Exists (_,_,n2,f2)) =>
          (case NameSet.compare (n1,n2) of
             EQUAL => compare (f1,f2)
           | ord => ord)
        | (Forall (_,_,n1,f1), Forall (_,_,n2,f2)) =>
          (case NameSet.compare (n1,n2) of
             EQUAL => compare (f1,f2)
           | ord => ord)
          (* Two constants, or two different constructors. *)
        | _ => Int.compare (rank lhs, rank rhs)
      end;

(* The empty set of formulas, ordered by compare. *)
val empty = Set.empty compare;

(* Build a one-element set of formulas, ordered by compare. *)
val singleton = Set.singleton compare;

local
  (* Negate a formula by dualizing it: constants and literals flip, And   *)
  (* and Or swap (with negated members and count), Xor flips its          *)
  (* polarity bit, and the quantifiers swap over a negated body.          *)
  fun flip fm =
      case fm of
        True => False
      | False => True
      | Literal (fv,lit) => Literal (fv, Literal.negate lit)
      | And (fv,c,s) => Or (fv, countNegate c, flipAll s)
      | Or (fv,c,s) => And (fv, countNegate c, flipAll s)
      | Xor (fv,c,p,s) => Xor (fv, c, not p, s)
      | Exists (fv,c,n,f) => Forall (fv, countNegate c, n, flip f)
      | Forall (fv,c,n,f) => Exists (fv, countNegate c, n, flip f)

  (* Negate every member of a set of formulas. *)
  and flipAll s = Set.foldl (fn (fm,acc) => Set.add acc (flip fm)) empty s;
in
  val negate = flip;

  val negateSet = flipAll;
end;

(* Does the set contain the negation of the given formula? *)
fun negateMember fm fms =
    let
      val negated = negate fm
    in
      Set.member negated fms
    end;

(* True when no formula of one set has its negation in the other set. *)
fun negateDisjoint s1 s2 =
    let
      (* Walk the smaller set and probe the larger one. *)
      val (small,large) =
          if Set.size s1 < Set.size s2 then (s1,s2) else (s2,s1)

      fun clashes fm = negateMember fm large
    in
      not (Set.exists clashes small)
    end;

(* The polarity flag of a formula. A literal yields the opposite of its *)
(* sign and a Xor yields its stored flag; True, And and Exists yield    *)
(* true, while False, Or and Forall yield false.                        *)
fun polarity fm =
    case fm of
      Literal (_,(sign,_)) => not sign
    | Xor (_,_,flag,_) => flag
    | True => true
    | And _ => true
    | Exists _ => true
    | False => false
    | Or _ => false
    | Forall _ => false;

(*MetisDebug
val polarity = fn f =>
    let
      val res1 = compare (f, negate f) = LESS
      val res2 = polarity f
      val _ = res1 = res2 orelse raise Bug "polarity"
    in
      res2
    end;
*)

(* Return the formula unchanged for polarity true, negated for false. *)
fun applyPolarity keep fm = if keep then fm else negate fm;

(* The cached free-variable set of a formula; empty for the constants. *)
fun freeVars fm =
    case fm of
      True => NameSet.empty
    | False => NameSet.empty
    | Literal (vars,_) => vars
    | And (vars,_,_) => vars
    | Or (vars,_,_) => vars
    | Xor (vars,_,_,_) => vars
    | Exists (vars,_,_,_) => vars
    | Forall (vars,_,_,_) => vars;

(* Is the variable v among the free variables of fm? *)
fun freeIn v fm =
    let
      val vars = freeVars fm
    in
      NameSet.member v vars
    end;

(* Union of the free variables of every formula in a set. *)
fun freeVarsSet fms =
    Set.foldl (fn (fm,vars) => NameSet.union (freeVars fm) vars)
      NameSet.empty fms;

(* The clause count of a formula. Constants and literals have fixed      *)
(* counts; the other constructors carry a cached count, which a Xor with *)
(* polarity false reports negated.                                       *)
fun count fm =
    case fm of
      True => countTrue
    | False => countFalse
    | Literal _ => countLiteral
    | And (_,cached,_) => cached
    | Or (_,cached,_) => cached
    | Xor (_,cached,true,_) => cached
    | Xor (_,cached,false,_) => countNegate cached
    | Exists (_,cached,_,_) => cached
    | Forall (_,cached,_,_) => cached;

(* Count of the conjunction of every formula in a set. *)
fun countAndSet fms =
    Set.foldl (fn (fm,acc) => countAnd2 (count fm, acc)) countTrue fms;

(* Count of the disjunction of every formula in a set. *)
fun countOrSet fms =
    Set.foldl (fn (fm,acc) => countOr2 (count fm, acc)) countFalse fms;

(* Count of the exclusive-or of every formula in a set. *)
fun countXorSet fms =
    Set.foldl (fn (fm,acc) => countXor2 (count fm, acc)) countFalse fms;

(* Smart constructor for a binary conjunction. Constants are absorbed,   *)
(* nested conjunctions are flattened into one set, and the result is     *)
(* False when one side contains the negation of a member of the other.   *)
fun And2 (False,_) = False
  | And2 (_,False) = False
  | And2 (True,rhs) = rhs
  | And2 (lhs,True) = lhs
  | And2 (lhs,rhs) =
    let
      (* View an operand as (free vars, count, conjunct set). *)
      fun parts (And fv_c_s) = fv_c_s
        | parts fm = (freeVars fm, count fm, singleton fm)

      val (fvL,cL,sL) = parts lhs
      val (fvR,cR,sR) = parts rhs
    in
      if negateDisjoint sL sR then
        let
          val merged = Set.union sL sR
          val size = Set.size merged
        in
          if size = 0 then True
          else if size = 1 then Set.pick merged
          else if size = Set.size sL + Set.size sR then
            (* No shared conjuncts: the cached data can be combined. *)
            And (NameSet.union fvL fvR, countAnd2 (cL,cR), merged)
          else
            (* Some conjuncts coincided: recompute from the merged set. *)
            And (freeVarsSet merged, countAndSet merged, merged)
        end
      else False
    end;

(* Conjunction of a list of formulas (True for the empty list). *)
fun AndList fms = List.foldl And2 True fms;

(* Conjunction of a set of formulas (True for the empty set). *)
fun AndSet fms = Set.foldl And2 True fms;

(* Smart constructor for a binary disjunction. Constants are absorbed,   *)
(* nested disjunctions are flattened into one set, and the result is     *)
(* True when one side contains the negation of a member of the other.    *)
fun Or2 (True,_) = True
  | Or2 (_,True) = True
  | Or2 (False,rhs) = rhs
  | Or2 (lhs,False) = lhs
  | Or2 (lhs,rhs) =
    let
      (* View an operand as (free vars, count, disjunct set). *)
      fun parts (Or fv_c_s) = fv_c_s
        | parts fm = (freeVars fm, count fm, singleton fm)

      val (fvL,cL,sL) = parts lhs
      val (fvR,cR,sR) = parts rhs
    in
      if negateDisjoint sL sR then
        let
          val merged = Set.union sL sR
          val size = Set.size merged
        in
          if size = 0 then False
          else if size = 1 then Set.pick merged
          else if size = Set.size sL + Set.size sR then
            (* No shared disjuncts: the cached data can be combined. *)
            Or (NameSet.union fvL fvR, countOr2 (cL,cR), merged)
          else
            (* Some disjuncts coincided: recompute from the merged set. *)
            Or (freeVarsSet merged, countOrSet merged, merged)
        end
      else True
    end;

(* Disjunction of a list of formulas (False for the empty list). *)
fun OrList fms = List.foldl Or2 False fms;

(* Disjunction of a set of formulas (False for the empty set). *)
fun OrSet fms = Set.foldl Or2 False fms;

(* Distribute a disjunction over conjunctions: each operand is viewed as *)
(* a set of conjuncts, and the result is the conjunction of Or2 (x1,x2)  *)
(* over every pair drawn from the two sets.                              *)
fun pushOr2 (f1,f2) =
    let
      fun conjuncts (And (_,_,s)) = s
        | conjuncts fm = singleton fm

      val left = conjuncts f1
      val right = conjuncts f2

      fun outer (x1,acc) =
          Set.foldl (fn (x2,acc) => And2 (Or2 (x1,x2), acc)) acc right
    in
      Set.foldl outer True left
    end;

(* Fold pushOr2 over a list of formulas, starting from False. *)
fun pushOrList fms = List.foldl pushOr2 False fms;

local
  (* View an operand as (free vars, count, polarity, member set). A Xor   *)
  (* is taken apart directly; any other formula is first put into its     *)
  (* polarity-true form and wrapped in a singleton set.                   *)
  fun parts (Xor x) = x
    | parts fm =
      let
        val pol = polarity fm
        val fm = applyPolarity pol fm
      in
        (freeVars fm, count fm, pol, singleton fm)
      end;
in
  (* Smart constructor for a binary exclusive-or. Constants are absorbed, *)
  (* members occurring on both sides cancel (symmetric difference), and   *)
  (* the two polarities are folded into the result.                       *)
  fun Xor2 (False,rhs) = rhs
    | Xor2 (lhs,False) = lhs
    | Xor2 (True,rhs) = negate rhs
    | Xor2 (lhs,True) = negate lhs
    | Xor2 (lhs,rhs) =
      let
        val (fvL,cL,pL,sL) = parts lhs
        val (fvR,cR,pR,sR) = parts rhs

        val diff = Set.symmetricDifference sL sR
        val size = Set.size diff

        val body =
            if size = 0 then False
            else if size = 1 then Set.pick diff
            else if size = Set.size sL + Set.size sR then
              (* Nothing cancelled: the cached data can be combined. *)
              Xor (NameSet.union fvL fvR, countXor2 (cL,cR), true, diff)
            else
              (* Some members cancelled: recompute from the new set. *)
              Xor (freeVarsSet diff, countXorSet diff, true, diff)
      in
        applyPolarity (pL = pR) body
      end;
end;

(* Exclusive-or of a list of formulas (False for the empty list). *)
fun XorList fms = List.foldl Xor2 False fms;

(* Exclusive-or of a set of formulas (False for the empty set). *)
fun XorSet fms = Set.foldl Xor2 False fms;

(* Exclusive-or of a list, negated when the polarity is false. *)
fun XorPolarityList (pol,fms) =
    let
      val fm = XorList fms
    in
      applyPolarity pol fm
    end;

(* Exclusive-or of a set, negated when the polarity is false. *)
fun XorPolaritySet (pol,fms) =
    let
      val fm = XorSet fms
    in
      applyPolarity pol fm
    end;

(* Split a Xor into one member and the exclusive-or of the remaining     *)
(* members, which carries the original polarity. Raises Error "destXor"  *)
(* for any other formula.                                                *)
fun destXor fm =
    case fm of
      Xor (_,_,pol,fms) =>
      let
        val (first,others) = Set.deletePick fms

        val rest =
            case Set.size others of
              1 => applyPolarity pol (Set.pick others)
            | _ => Xor (freeVarsSet others, countXorSet others, pol, others)
      in
        (first,rest)
      end
    | _ => raise Error "destXor";

(* Rewrite a Xor, split by destXor into (a,b), as the conjunction *)
(*   (a \/ b) /\ (~a \/ ~b).                                      *)
fun pushXor fm =
    let
      val (a,b) = destXor fm
      val either = Or2 (a,b)
      val notBoth = Or2 (negate a, negate b)
    in
      And2 (either,notBoth)
    end;

(* Existentially quantify the variable v over a formula, pushing the     *)
(* quantifier inwards where the shape of the formula allows it.          *)
fun Exists1 (v,init_fm) =
    let
      (* Fallback: wrap fm in a new Exists node that binds just v. *)
      fun exists_gen fm =
          let
            val fv = NameSet.delete (freeVars fm) v
            val c = count fm
            val n = NameSet.singleton v
          in
            Exists (fv,c,n,fm)
          end

      (* A formula in which v is not free is returned unchanged. *)
      fun exists fm = if freeIn v fm then exists_free fm else fm

      (* v is free in the argument. Over a disjunction the quantifier is *)
      (* distributed to every disjunct.                                  *)
      and exists_free (Or (_,_,s)) = OrList (Set.transform exists s)
        | exists_free (fm as And (_,_,s)) =
          let
            (* The conjuncts that mention v. *)
            val sv = Set.filter (freeIn v) s
          in
            if Set.size sv <> 1 then exists_gen fm
            else
              (* Exactly one conjunct mentions v: quantify only that one *)
              (* and conjoin it with the remaining conjuncts.            *)
              let
                val fm = Set.pick sv
                val s = Set.delete s fm
              in
                And2 (exists_free fm, AndSet s)
              end
          end
        (* Add v to the variables bound by an existing Exists node. *)
        | exists_free (Exists (fv,c,n,f)) =
          Exists (NameSet.delete fv v, c, NameSet.add n v, f)
        | exists_free fm = exists_gen fm
    in
      exists init_fm
    end;

(* Existentially quantify each variable of a list, in list order. *)
fun ExistsList (vars,body) =
    let
      fun bind (v,fm) = Exists1 (v,fm)
    in
      List.foldl bind body vars
    end;

(* Existentially quantify each variable of a name set. *)
fun ExistsSet (vars,body) =
    let
      fun bind (v,fm) = Exists1 (v,fm)
    in
      NameSet.foldl bind body vars
    end;

(* Universally quantify the variable v over a formula, pushing the       *)
(* quantifier inwards where the shape of the formula allows it.          *)
fun Forall1 (v,init_fm) =
    let
      (* Fallback: wrap fm in a new Forall node that binds just v. *)
      fun forall_gen fm =
          let
            val fv = NameSet.delete (freeVars fm) v
            val c = count fm
            val n = NameSet.singleton v
          in
            Forall (fv,c,n,fm)
          end

      (* A formula in which v is not free is returned unchanged. *)
      fun forall fm = if freeIn v fm then forall_free fm else fm

      (* v is free in the argument. Over a conjunction the quantifier is *)
      (* distributed to every conjunct.                                  *)
      and forall_free (And (_,_,s)) = AndList (Set.transform forall s)
        | forall_free (fm as Or (_,_,s)) =
          let
            (* The disjuncts that mention v. *)
            val sv = Set.filter (freeIn v) s
          in
            if Set.size sv <> 1 then forall_gen fm
            else
              (* Exactly one disjunct mentions v: quantify only that one *)
              (* and disjoin it with the remaining disjuncts.            *)
              let
                val fm = Set.pick sv
                val s = Set.delete s fm
              in
                Or2 (forall_free fm, OrSet s)
              end
          end
        (* Add v to the variables bound by an existing Forall node. *)
        | forall_free (Forall (fv,c,n,f)) =
          Forall (NameSet.delete fv v, c, NameSet.add n v, f)
        | forall_free fm = forall_gen fm
    in
      forall init_fm
    end;

(* Universally quantify each variable of a list, in list order. *)
fun ForallList (vars,body) =
    List.foldl (fn (v,fm) => Forall1 (v,fm)) body vars;

(* Universally quantify each variable of a name set. *)
fun ForallSet (vars,body) =
    NameSet.foldl (fn (v,fm) => Forall1 (v,fm)) body vars;

(* Universally quantify every free variable of a formula. *)
fun generalize fm =
    let
      val vars = freeVars fm
    in
      ForallSet (vars,fm)
    end;

local
  (* Union of the free-variable sets that fvSub records for each         *)
  (* variable of the given set.                                          *)
  fun subst_fv fvSub =
      let
        fun add_fv (v,s) = NameSet.union (NameMap.get fvSub v) s
      in
        NameSet.foldl add_fv NameSet.empty
      end;

  (* Rename the bound variable v to a variant v' outside avoid, and      *)
  (* extend the bound set, substitution, domain and fvSub accordingly.   *)
  fun subst_rename (v,(avoid,bv,sub,domain,fvSub)) =
      let
        val v' = Term.variantPrime avoid v
        val avoid = NameSet.add avoid v'
        val bv = NameSet.add bv v'
        val sub = Subst.insert sub (v, Term.Var v')
        val domain = NameSet.add domain v
        val fvSub = NameMap.insert fvSub (v, NameSet.singleton v')
      in
        (avoid,bv,sub,domain,fvSub)
      end;

  (* Restrict the domain to the free variables of fm; if nothing is left *)
  (* then fm is returned unchanged.                                      *)
  fun subst_check sub domain fvSub fm =
      let
        val domain = NameSet.intersect domain (freeVars fm)
      in
        if NameSet.null domain then fm
        else subst_domain sub domain fvSub fm
      end

  (* Apply the substitution to fm and rebuild it with the smart          *)
  (* constructors. True and False have no free variables, so reaching    *)
  (* them here raises Bug.                                               *)
  and subst_domain sub domain fvSub fm =
      case fm of
        Literal (fv,lit) =>
        let
          val fv = NameSet.difference fv domain
          val fv = NameSet.union fv (subst_fv fvSub domain)
          val lit = Literal.subst sub lit
        in
          Literal (fv,lit)
        end
      | And (_,_,s) =>
        AndList (Set.transform (subst_check sub domain fvSub) s)
      | Or (_,_,s) =>
        OrList (Set.transform (subst_check sub domain fvSub) s)
      | Xor (_,_,p,s) =>
        XorPolarityList (p, Set.transform (subst_check sub domain fvSub) s)
      | Exists fv_c_n_f => subst_quant Exists sub domain fvSub fv_c_n_f
      | Forall fv_c_n_f => subst_quant Forall sub domain fvSub fv_c_n_f
      | _ => raise Bug "subst_domain"

  (* Substitute under a quantifier. Bound variables that also occur free *)
  (* in the substituted terms (captured) are renamed first.              *)
  and subst_quant quant sub domain fvSub (fv,c,bv,fm) =
      let
        val sub_fv = subst_fv fvSub domain
        val fv = NameSet.union sub_fv (NameSet.difference fv domain)
        val captured = NameSet.intersect bv sub_fv
        val bv = NameSet.difference bv captured
        val avoid = NameSet.union fv bv
        val (_,bv,sub,domain,fvSub) =
            NameSet.foldl subst_rename (avoid,bv,sub,domain,fvSub) captured
        val fm = subst_domain sub domain fvSub fm
      in
        quant (fv,c,bv,fm)
      end;
in
  (* Apply a substitution to a normalized formula. The domain of the     *)
  (* substitution and the free variables of each substituted term are    *)
  (* computed once, before the formula argument is supplied.             *)
  fun subst sub =
      let
        fun mk_dom (v,tm,(d,fv)) =
            (NameSet.add d v, NameMap.insert fv (v, Term.freeVars tm))

        val domain_fvSub = (NameSet.empty, NameMap.new ())
        val (domain,fvSub) = Subst.foldl mk_dom domain_fvSub sub
      in
        subst_check sub domain fvSub
      end;
end;

(* Convert a Formula.formula into a normalized formula. Negation is      *)
(* handled by switching to negateFromFormula, implication becomes a      *)
(* disjunction, and p <=> q becomes Xor2 of ~p and q.                    *)
fun fromFormula fm =
    case fm of
      Formula.True => True
    | Formula.False => False
    | Formula.Atom atm => Literal (Atom.freeVars atm, (true,atm))
    | Formula.Not p => negateFromFormula p
    | Formula.And (p,q) => And2 (fromFormula p, fromFormula q)
    | Formula.Or (p,q) => Or2 (fromFormula p, fromFormula q)
    | Formula.Imp (p,q) => Or2 (negateFromFormula p, fromFormula q)
    | Formula.Iff (p,q) => Xor2 (negateFromFormula p, fromFormula q)
    | Formula.Forall (v,p) => Forall1 (v, fromFormula p)
    | Formula.Exists (v,p) => Exists1 (v, fromFormula p)

(* Convert the negation of a Formula.formula into a normalized formula:  *)
(* each case builds the dual of the corresponding fromFormula case.      *)
and negateFromFormula fm =
    case fm of
      Formula.True => False
    | Formula.False => True
    | Formula.Atom atm => Literal (Atom.freeVars atm, (false,atm))
    | Formula.Not p => fromFormula p
    | Formula.And (p,q) => Or2 (negateFromFormula p, negateFromFormula q)
    | Formula.Or (p,q) => And2 (negateFromFormula p, negateFromFormula q)
    | Formula.Imp (p,q) => And2 (fromFormula p, negateFromFormula q)
    | Formula.Iff (p,q) => Xor2 (fromFormula p, fromFormula q)
    | Formula.Forall (v,p) => Exists1 (v, negateFromFormula p)
    | Formula.Exists (v,p) => Forall1 (v, negateFromFormula p);

local
  (* The element returned by Set.findr with an always-true predicate;   *)
  (* raises Bug on an empty set.                                        *)
  fun lastElt (s : formula Set.set) =
      case Set.findr (K true) s of
        NONE => raise Bug "lastElt: empty set"
      | SOME fm => fm;

  (* Replace the element chosen by lastElt with its negation. *)
  fun negateLastElt s =
      let
        val fm = lastElt s
      in
        Set.add (Set.delete s fm) (negate fm)
      end;

  (* Convert a normalized formula back into a Formula.formula. *)
  fun form fm =
      case fm of
        True => Formula.True
      | False => Formula.False
      | Literal (_,lit) => Literal.toFormula lit
      | And (_,_,s) => Formula.listMkConj (Set.transform form s)
      | Or (_,_,s) => Formula.listMkDisj (Set.transform form s)
      | Xor (_,_,p,s) => xorForm p s
      | Exists (_,_,n,f) => Formula.listMkExists (NameSet.toList n, form f)
      | Forall (_,_,n,f) => Formula.listMkForall (NameSet.toList n, form f)

  (* To convert a Xor set to an Iff list we need to know   *)
  (* whether the size of the set is even or odd:           *)
  (*                                                       *)
  (*                   b XOR a = b <=> ~a                  *)
  (*             c XOR b XOR a = c <=> b <=> a             *)
  (*       d XOR c XOR b XOR a = d <=> c <=> b <=> ~a      *)
  (* e XOR d XOR c XOR b XOR a = e <=> d <=> c <=> b <=> a *)
  and xorForm p s =
      let
        (* An even-sized set flips the polarity; if the resulting *)
        (* polarity is false, one element is negated.             *)
        val p = if Set.size s mod 2 = 0 then not p else p
        val s = if p then s else negateLastElt s
      in
        Formula.listMkEquiv (Set.transform form s)
      end;
in
  val toFormula = form;
end;

(* Extract the literal from a Literal formula; any other formula raises *)
(* Error "Normalize.toLiteral".                                         *)
fun toLiteral fm =
    case fm of
      Literal (_,lit) => lit
    | _ => raise Error "Normalize.toLiteral";

(* Convert a formula to a set of literals: False gives the empty set, a  *)
(* disjunction gives the literals of its members, and anything else is   *)
(* treated as a single literal (toLiteral raises Error otherwise).       *)
fun toClause fm =
    case fm of
      False => LiteralSet.empty
    | Or (_,_,disjuncts) =>
      Set.foldl (fn (d,acc) => LiteralSet.add acc (toLiteral d))
        LiteralSet.empty disjuncts
    | _ => LiteralSet.singleton (toLiteral fm);

(* Pretty-print a normalized formula by converting it with toFormula and *)
(* printing the result with Formula.pp.                                  *)
val pp = Print.ppMap toFormula Formula.pp;

(* String rendering of a normalized formula, derived from pp. *)
val toString = Print.toString pp;

(* ------------------------------------------------------------------------- *)
(* Negation normal form.                                                     *)
(* ------------------------------------------------------------------------- *)

(* Round-trip a formula through the normalized representation. *)
val nnf = toFormula o fromFormula;

(* ------------------------------------------------------------------------- *)
(* Basic conjunctive normal form.                                            *)
(* ------------------------------------------------------------------------- *)

local
  (* For each variable name (as a string), how many Skolem function      *)
  (* names have been generated from it so far.                           *)
  val counter : int StringMap.map ref = ref (StringMap.new ());

  (* Generate the next name for n: skolemPrefix ^ "_" ^ name on first    *)
  (* use, with an extra "_" ^ i suffix on the i-th later use. Updates    *)
  (* counter as a side effect.                                           *)
  fun new n () =
      let
        val ref m = counter
        val s = Name.toString n
        val i = Option.getOpt (StringMap.peek m s, 0)
        val () = counter := StringMap.insert m (s, i + 1)
        val i = if i = 0 then "" else "_" ^ Int.toString i
        val s = skolemPrefix ^ "_" ^ s ^ i
      in
        Name.fromString s
      end;
in
  (* NOTE(review): Portable.critical presumably guards the counter       *)
  (* update against concurrent access; confirm against Portable.         *)
  fun newSkolemFunction n = Portable.critical (new n) ();
end;

(* Replace each variable in bv by a fresh Skolem function applied to the *)
(* variables in fv, and apply the resulting substitution to fm.          *)
fun skolemize fv bv fm =
    let
      val args = NameSet.transform Term.Var fv

      fun bind (v,sub) =
          let
            val tm = Term.Fn (newSkolemFunction v, args)
          in
            Subst.insert sub (v,tm)
          end

      val sub = NameSet.foldl bind Subst.empty bv
    in
      subst sub fm
    end;

local
  (* Rename those bound variables bv of fm that clash with avoid. The    *)
  (* new names are variants chosen to lie outside avoid, fv and bv. If   *)
  (* nothing clashes, fm is returned unchanged.                          *)
  fun rename avoid fv bv fm =
      let
        val captured = NameSet.intersect avoid bv
      in
        if NameSet.null captured then fm
        else
          let
            fun ren (v,(a,s)) =
                let
                  val v' = Term.variantPrime a v
                in
                  (NameSet.add a v', Subst.insert s (v, Term.Var v'))
                end

            val avoid = NameSet.union (NameSet.union avoid fv) bv

            val (_,sub) = NameSet.foldl ren (avoid,Subst.empty) captured
          in
            subst sub fm
          end
      end;

  (* Convert a normalized formula to conjunctive normal form. avoid is   *)
  (* the set of variable names that quantified variables are renamed     *)
  (* away from. Conjunctions are converted member-wise; disjunctions are *)
  (* converted member by member (see cnfOr) and then distributed with    *)
  (* pushOrList; Xor is expanded with pushXor; Exists is skolemized over *)
  (* its free variables; Forall has its binder dropped after renaming.   *)
  fun cnfFm avoid fm =
(*MetisTrace5
      let
        val fm' = cnfFm' avoid fm
        val () = Print.trace pp "Normalize.cnfFm: fm" fm
        val () = Print.trace pp "Normalize.cnfFm: fm'" fm'
      in
        fm'
      end
  and cnfFm' avoid fm =
*)
      case fm of
        True => True
      | False => False
      | Literal _ => fm
      | And (_,_,s) => AndList (Set.transform (cnfFm avoid) s)
      | Or (fv,_,s) =>
        let
          val avoid = NameSet.union avoid fv
          val (fms,_) = Set.foldl cnfOr ([],avoid) s
        in
          pushOrList fms
        end
      | Xor _ => cnfFm avoid (pushXor fm)
      | Exists (fv,_,n,f) => cnfFm avoid (skolemize fv n f)
      | Forall (fv,_,n,f) => cnfFm avoid (rename avoid fv n f)

  (* Convert one disjunct, then add its free variables to avoid so that  *)
  (* the disjuncts converted after it are renamed away from them.        *)
  and cnfOr (fm,(fms,avoid)) =
      let
        val fm = cnfFm avoid fm
        val fms = fm :: fms
        val avoid = NameSet.union avoid (freeVars fm)
      in
        (fms,avoid)
      end;
in
  (* Conjunctive normal form conversion with an empty avoid set. *)
  val basicCnf = cnfFm NameSet.empty;
end;

(* ------------------------------------------------------------------------- *)
(* Finding the formula definition that minimizes the number of clauses.      *)
(* ------------------------------------------------------------------------- *)

local
  (* The lowest clause count found so far, paired with the subformula    *)
  (* whose definition achieves it (NONE if no candidate has been         *)
  (* accepted).                                                          *)
  type best = count * formula option;

  (* Search fm for a subformula worth defining. countClauses is applied  *)
  (* to the count of a replacement for fm and yields the count to be     *)
  (* compared against best. Constants and literals offer no candidates;  *)
  (* quantifiers are searched through to their body.                     *)
  fun minBreak countClauses fm best =
      case fm of
        True => best
      | False => best
      | Literal _ => best
      | And (_,_,s) =>
        minBreakSet countClauses countAnd2 countTrue AndSet s best
      | Or (_,_,s) =>
        minBreakSet countClauses countOr2 countFalse OrSet s best
      | Xor (_,_,_,s) =>
        minBreakSet countClauses countXor2 countFalse XorSet s best
      | Exists (_,_,_,f) => minBreak countClauses f best
      | Forall (_,_,_,f) => minBreak countClauses f best

  (* Search the members of an And/Or/Xor set. count2 and count0 are the  *)
  (* binary count combiner and its unit for the connective, and mkSet    *)
  (* rebuilds a formula of that connective from a set of members.        *)
  and minBreakSet countClauses count2 count0 mkSet fmSet best =
      let
        (* For each fm of the list, in list order, produce               *)
        (* (c1,s1,fm,c2,s2): c1 and s1 are the combined count and the    *)
        (* set of the members before fm, c2 and s2 of those after it.    *)
        fun cumulatives fms =
            let
              fun fwd (fm,(c1,s1,l)) =
                  let
                    val c1' = count2 (count fm, c1)
                    and s1' = Set.add s1 fm
                  in
                    (c1', s1', (c1,s1,fm) :: l)
                  end

              fun bwd ((c1,s1,fm),(c2,s2,l)) =
                  let
                    val c2' = count2 (count fm, c2)
                    and s2' = Set.add s2 fm
                  in
                    (c2', s2', (c1,s1,fm,c2,s2) :: l)
                  end

              val (c1,_,fms) = List.foldl fwd (count0,empty,[]) fms
              val (c2,_,fms) = List.foldl bwd (count0,empty,[]) fms

(*MetisDebug
              val _ = countEquivish c1 c2 orelse
                      raise Bug ("cumulativeCounts: c1 = " ^ countToString c1 ^
                                 ", c2 = " ^ countToString c2)
*)
            in
              fms
            end

        (* Recurse into a single member, extending countClauses with the *)
        (* combined count of all the other members.                      *)
        fun breakSing ((c1,_,fm,c2,_),best) =
            let
              val cFms = count2 (c1,c2)

              fun countCls cFm = countClauses (count2 (cFms,cFm))
            in
              minBreak countCls fm best
            end

        (* Consider defining the members before fm together with fm, and *)
        (* likewise the members after fm together with fm.               *)
        val breakSet1 =
            let
              (* Candidate: define mkSet (s1 plus fm). Skipped when s1   *)
              (* is empty. Its cost cl combines the count of the         *)
              (* definition with the count of the enclosing formula in   *)
              (* which the defined part is replaced by a literal. The    *)
              (* candidate replaces best unless countLeqish bcl cl.      *)
              fun break c1 s1 fm c2 (best as (bcl,_)) =
                  if Set.null s1 then best
                  else
                    let
                      val cDef = countDefinition (countXor2 (c1, count fm))
                      val cFm = count2 (countLiteral,c2)
                      val cl = countAnd2 (cDef, countClauses cFm)
                      val noBetter = countLeqish bcl cl
                    in
                      if noBetter then best
                      else (cl, SOME (mkSet (Set.add s1 fm)))
                    end
            in
              fn ((c1,s1,fm,c2,s2),best) =>
                 break c1 s1 fm c2 (break c2 s2 fm c1 best)
            end

        val fms = Set.toList fmSet

        (* Try breakSet1 on the members ordered by a measure of their    *)
        (* count. NOTE(review): sortMap is not defined in this part of   *)
        (* the file; presumably it sorts by the mapped key using         *)
        (* countCompare.                                                 *)
        fun breakSet measure best =
            let
              val fms = sortMap (measure o count) countCompare fms
            in
              List.foldl breakSet1 best (cumulatives fms)
            end

        (* First recurse into each member, then try three orderings. *)
        val best = List.foldl breakSing best (cumulatives fms)
        val best = breakSet I best
        val best = breakSet countNegate best
        val best = breakSet countClauses best
      in
        best
      end
in
  (* Return the subformula of fm whose definition is estimated to reduce *)
  (* the clause count the most, or NONE when fm already counts no more   *)
  (* than a literal (within tolerance) or no candidate improves on it.   *)
  fun minimumDefinition fm =
      let
        val cl = count fm
      in
        if countLeqish cl countLiteral then NONE
        else
          let
            val (cl',def) = minBreak I fm (cl,NONE)
(*MetisTrace1
            val () =
                case def of
                  NONE => ()
                | SOME d =>
                  Print.trace pp ("defCNF: before = " ^ countToString cl ^
                                  ", after = " ^ countToString cl' ^
                                  ", definition") d
*)
          in
            def
          end
      end;
end;

(* ------------------------------------------------------------------------- *)
(* Conjunctive normal form derivations.                                      *)
(* ------------------------------------------------------------------------- *)

(* A derived normalized formula together with the inference that         *)
(* produced it.                                                          *)
datatype thm = Thm of formula * inference

(* The inference steps. Axiom and Definition carry a Formula.formula     *)
(* and have no parent theorems; the other steps carry their parent       *)
(* theorem(s), as listed by parentsInference.                            *)
and inference =
    Axiom of Formula.formula
  | Definition of string * Formula.formula
  | Simplify of thm * thm list
  | Conjunct of thm
  | Specialize of thm
  | Skolemize of thm
  | Clausify of thm;

(* The parent theorems of an inference step; empty for Axiom and *)
(* Definition.                                                   *)
fun parentsInference (Axiom _) = []
  | parentsInference (Definition _) = []
  | parentsInference (Simplify (th,ths)) = th :: ths
  | parentsInference (Conjunct th) = [th]
  | parentsInference (Specialize th) = [th]
  | parentsInference (Skolemize th) = [th]
  | parentsInference (Clausify th) = [th];

(* Order theorems by their formulas, ignoring the inferences. *)
fun compareThm (th1,th2) =
    let
      val Thm (body1,_) = th1
      and Thm (body2,_) = th2
    in
      compare (body1,body2)
    end;

(* The parent theorems of a theorem's inference. *)
fun parentsThm th =
    let
      val Thm (_,inf) = th
    in
      parentsInference inf
    end;

(* Turn a Formula.formula into an axiom theorem. *)
fun mkAxiom fm =
    let
      val body = fromFormula fm
    in
      Thm (body, Axiom fm)
    end;

(* Take a theorem apart into its formula (converted by toFormula) and *)
(* its inference.                                                     *)
fun destThm (Thm (body,inf)) =
    let
      val fm = toFormula body
    in
      (fm,inf)
    end;

local
  (* Map from already-emitted theorems to their Formula.formula. *)
  val noneProved : (thm,Formula.formula) Map.map = Map.new compareThm;

  fun alreadyProved proved th = Map.inDomain th proved;

  fun provedFormula proved th =
      case Map.peek proved th of
        NONE => raise Bug "Normalize.lookupProved"
      | SOME fm => fm;

  (* Work-list loop. A theorem whose parents have all been emitted is    *)
  (* emitted itself as (formula, inference, parent formulas); otherwise  *)
  (* its missing parents are pushed in front of the work list and the    *)
  (* theorem is revisited after them.                                    *)
  fun loop done _ [] = List.rev done
    | loop done proved (pending as th :: rest) =
      if alreadyProved proved th then loop done proved rest
      else
        let
          val parents = parentsThm th
        in
          case List.filter (not o alreadyProved proved) parents of
            [] =>
            let
              val (fm,inf) = destThm th

              val parentFms = List.map (provedFormula proved) parents
            in
              loop ((fm,inf,parentFms) :: done)
                   (Map.insert proved (th,fm))
                   rest
            end
          | missing => loop done proved (missing @ pending)
        end;
in
  (* List every theorem needed by the given theorems, parents before *)
  (* children, each theorem once.                                    *)
  val proveThms = loop [] noneProved;
end;

(* The name of an inference step's constructor. *)
fun toStringInference (Axiom _) = "Axiom"
  | toStringInference (Definition _) = "Definition"
  | toStringInference (Simplify _) = "Simplify"
  | toStringInference (Conjunct _) = "Conjunct"
  | toStringInference (Specialize _) = "Specialize"
  | toStringInference (Skolemize _) = "Skolemize"
  | toStringInference (Clausify _) = "Clausify";

(* Pretty-printer for inferences: toStringInference followed by *)
(* Print.ppString.                                              *)
val ppInference = Print.ppMap toStringInference Print.ppString;

(* ------------------------------------------------------------------------- *)
(* Simplifying with definitions.                                             *)
(* ------------------------------------------------------------------------- *)

(* A simplification context. `formula` maps a formula to its replacement *)
(* and the theorem justifying it. The three lists hold, per connective,  *)
(* triples of (set of members, replacement formula, justifying theorem). *)
datatype simplify =
    Simp of
      {formula : (formula, formula * thm) Map.map,
       andSet : (formula Set.set * formula * thm) list,
       orSet : (formula Set.set * formula * thm) list,
       xorSet : (formula Set.set * formula * thm) list};

(* The context holding no definitions. *)
val simplifyEmpty =
    Simp
      {formula = Map.new compare,
       andSet = [],
       orSet = [],
       xorSet = []};

local
  fun simpler fm s =
      Set.size s <> 1 orelse
      case Set.pick s of
        True => false
      | False => false
      | Literal _ => false
      | _ => true;

  fun addSet set_defs body_def =
      let
        fun def_body_size (body,_,_) = Set.size body

        val body_size = def_body_size body_def

        val (body,_,_) = body_def

        fun add acc [] = List.revAppend (acc,[body_def])
          | add acc (l as (bd as (b,_,_)) :: bds) =
            case Int.compare (def_body_size bd, body_size) of
              LESS => List.revAppend (acc, body_def :: l)
            | EQUAL =>
              if Set.equal b body then List.revAppend (acc,l)
              else add (bd :: acc) bds
            | GREATER => add (bd :: acc) bds
      in
        add [] set_defs
      end;

  (* add simp (body,def,th) records in simp that the formula body can be   *)
  (* replaced by def, keeping th as the theorem to cite when the rewrite   *)
  (* is applied (the stored triples are consumed by simplify). The clauses *)
  (* are tried in order, so the checks on a False def and on a True body   *)
  (* take precedence over the dispatch on the shape of body.               *)

  (* A rewrite to False is stored as the negated body rewriting to True. *)
  fun add simp (body,False,th) = add simp (negate body, True, th)
    (* Nothing is recorded when the body is True. *)
    | add simp (True,_,_) = simp
    (* Conjunction: the member set s goes into andSet; the negated members *)
    (* with the negated def go into orSet.                                 *)
    | add (Simp {formula,andSet,orSet,xorSet}) (And (_,_,s), def, th) =
      let
        val andSet = addSet andSet (s,def,th)
        and orSet = addSet orSet (negateSet s, negate def, th)
      in
        Simp
          {formula = formula,
           andSet = andSet,
           orSet = orSet,
           xorSet = xorSet}
      end
    (* Disjunction: the mirror image of the conjunction case. *)
    | add (Simp {formula,andSet,orSet,xorSet}) (Or (_,_,s), def, th) =
      let
        val orSet = addSet orSet (s,def,th)
        and andSet = addSet andSet (negateSet s, negate def, th)
      in
        Simp
          {formula = formula,
           andSet = andSet,
           orSet = orSet,
           xorSet = xorSet}
      end
    (* Xor with polarity flag p: the member set is recorded through        *)
    (* addXorSet against applyPolarity p def.                              *)
    (* NOTE(review): applyPolarity is defined outside this chunk;          *)
    (* presumably it negates def when p is false - confirm.                *)
    | add simp (Xor (_,_,p,s), def, th) =
      let
        val simp = addXorSet simp (s, applyPolarity p def, th)
      in
        case def of
          True =>
          (* The Xor itself is being recorded as True. For every Literal   *)
          (* member fm, the set s without fm is additionally recorded as   *)
          (* rewriting to applyPolarity (not p) fm, but only when          *)
          (* simpler fm s' holds for the reduced set s'.                   *)
          (* NOTE(review): simpler is defined above this chunk; presumably *)
          (* it guards against rewrites that would not simplify - confirm. *)
          let
            fun addXorLiteral (fm as Literal _, simp) =
                let
                  (* Shadows the outer s with the set minus this literal. *)
                  val s = Set.delete s fm
                in
                  if not (simpler fm s) then simp
                  else addXorSet simp (s, applyPolarity (not p) fm, th)
                end
              | addXorLiteral (_,simp) = simp
          in
            Set.foldl addXorLiteral simp s
          end
        | _ => simp
      end
    (* Any other body is stored in the formula map, unless body is already *)
    (* a key there; the negated body is stored alongside, mapping to the   *)
    (* negated def with the same theorem.                                  *)
    | add (simp as Simp {formula,andSet,orSet,xorSet}) (body,def,th) =
      if Map.inDomain body formula then simp
      else
        let
          val formula = Map.insert formula (body,(def,th))
          val formula = Map.insert formula (negate body, (negate def, th))
        in
          Simp
            {formula = formula,
             andSet = andSet,
             orSet = orSet,
             xorSet = xorSet}
        end

  (* addXorSet simp (s,def,th) records a rewrite for the Xor member set s. *)
  (* A set of size one is handed back to add as a plain formula (its       *)
  (* single member, obtained with Set.pick); any other set is inserted     *)
  (* into xorSet with addSet.                                              *)
  and addXorSet (simp as Simp {formula,andSet,orSet,xorSet}) (s,def,th) =
      if Set.size s = 1 then add simp (Set.pick s, def, th)
      else
        let
          val xorSet = addSet xorSet (s,def,th)
        in
          Simp
            {formula = formula,
             andSet = andSet,
             orSet = orSet,
             xorSet = xorSet}
        end;
in
  (* Register the theorem th with the simplifier state simp: the formula   *)
  (* carried by th is recorded as rewriting to True, with th itself as the *)
  (* justification.                                                        *)
  fun simplifyAdd simp th =
      case th of
        Thm (fm,_) => add simp (fm,True,th);
end;

local
  (* Try to shrink set using one stored definition. The first entry         *)
  (* (s,f,th) of set_defs whose body s is a subset of set is used: the      *)
  (* members of s are removed from set, f is added in their place, and the  *)
  (* new set is returned together with the justifying theorem th. The       *)
  (* result is NONE when no entry applies.                                  *)
  fun simplifySet set_defs set =
      case List.find (fn (s,_,_) => Set.subset s set) set_defs of
        SOME (s,f,th) => SOME (Set.add (Set.difference set s) f, th)
      | NONE => NONE;
in
  (* simplify simp th rewrites the formula of th with the definitions held  *)
  (* in simp. Subformulas are rewritten first (members of And/Or/Xor sets   *)
  (* and bodies of quantifiers); the resulting formula is then rewritten at *)
  (* the top until no stored definition applies. The theorems used are      *)
  (* collected with the most recently applied one first. If anything        *)
  (* changed, the result is Thm (fm', Simplify (th,ths)); otherwise th is   *)
  (* returned as it is.                                                     *)
  fun simplify (Simp {formula,andSet,orSet,xorSet}) =
      let
        (* Apply mk to the payload of a successful rewrite, keeping the *)
        (* accumulated theorems; failures pass through.                 *)
        fun rebuild _ NONE = NONE
          | rebuild mk (SOME (x,ths)) = SOME (mk x, ths)

        (* Full rewrite of fm: inside first, then at the top. *)
        fun rewrite fm ths =
            case rewriteInside fm ths of
              SOME (fm',ths') => rewriteTopFully fm' ths'
            | NONE => rewriteTop fm ths

        (* Like rewriteTop, but reports fm itself when nothing applies. *)
        and rewriteTopFully fm ths =
            case rewriteTop fm ths of
              NONE => SOME (fm,ths)
            | result => result

        (* One or more rewrites at the root of fm; NONE if none applies. *)
        and rewriteTop fm ths =
            let
              fun viaSet defs mk s =
                  case simplifySet defs s of
                    SOME (s',th) => rewriteTopFully (mk s') (th :: ths)
                  | NONE => NONE
            in
              case fm of
                And (_,_,s) => viaSet andSet AndSet s
              | Or (_,_,s) => viaSet orSet OrSet s
              | Xor (_,_,p,s) =>
                viaSet xorSet (fn s' => XorPolaritySet (p,s')) s
              | _ =>
                (case Map.peek formula fm of
                   SOME (fm',th) => rewriteTopFully fm' (th :: ths)
                 | NONE => NONE)
            end

        (* Rewrite the immediate subformulas of fm and rebuild it; NONE *)
        (* if no subformula changed.                                    *)
        and rewriteInside fm ths =
            case fm of
              And (_,_,s) => rebuild AndList (rewriteSet s ths)
            | Or (_,_,s) => rebuild OrList (rewriteSet s ths)
            | Xor (_,_,p,s) =>
              rebuild (fn l => XorPolarityList (p,l)) (rewriteSet s ths)
            | Exists (_,_,n,f) =>
              rebuild (fn f' => ExistsSet (n,f')) (rewrite f ths)
            | Forall (_,_,n,f) =>
              rebuild (fn f' => ForallSet (n,f')) (rewrite f ths)
            | _ => NONE

        (* Rewrite every member of the set s, producing the list of     *)
        (* (possibly rewritten) members; NONE if no member changed.     *)
        and rewriteSet s ths =
            let
              fun step (fm,(dirty,acc,ths)) =
                  case rewrite fm ths of
                    SOME (fm',ths') => (true, fm' :: acc, ths')
                  | NONE => (dirty, fm :: acc, ths)
            in
              case Set.foldr step (false,[],ths) s of
                (true,l,ths') => SOME (l,ths')
              | (false,_,_) => NONE
            end
      in
        fn th as Thm (fm,_) =>
           (case rewrite fm [] of
              NONE => th
            | SOME (fm',ths) => Thm (fm', Simplify (th,ths)))
      end;
end;

(*MetisTrace2
val simplify = fn simp => fn th as Thm (fm,_) =>
    let
      val th' as Thm (fm',_) = simplify simp th
      val () = if compare (fm,fm') = EQUAL then ()
               else (Print.trace pp "Normalize.simplify: fm" fm;
                     Print.trace pp "Normalize.simplify: fm'" fm')
    in
      th'
    end;
*)

(* ------------------------------------------------------------------------- *)
(* Definitions.                                                              *)
(* ------------------------------------------------------------------------- *)

local
  (* Index that the next generated relation name will carry. *)
  val nextIndex : int ref = ref 0;

  (* Read the current index, advance the counter, and build the name *)
  (* definitionPrefix ^ "_" ^ <index>.                               *)
  fun fresh () =
      let
        val i = !nextIndex
      in
        nextIndex := i + 1;
        definitionPrefix ^ "_" ^ Int.toString i
      end;
in
  (* Return a relation name not handed out before by this function.   *)
  (* The counter update is run through Portable.critical.             *)
  (* NOTE(review): presumably this makes the read-and-increment       *)
  (* atomic with respect to other threads - confirm against Portable. *)
  fun newDefinitionRelation () = Portable.critical fresh ();
end;

(* Introduce a fresh relation standing for the formula def.                 *)
(*                                                                           *)
(* The relation is applied to the free variables of def (as Term.Var        *)
(* arguments). The inference step recorded is Definition (rel,axiom), where *)
(* axiom is the universally quantified Formula.Iff between the new atom and *)
(* def. The theorem's formula is Xor2 of def with a literal built from the  *)
(* new atom and the boolean flag false.                                     *)
fun newDefinition def =
    let
      val vars = freeVars def
      val rel = newDefinitionRelation ()
      val atom = (Name.fromString rel, NameSet.transform Term.Var vars)
      val iff = Formula.Iff (Formula.Atom atom, toFormula def)
      val axiom = Formula.setMkForall (vars,iff)
    in
      Thm (Xor2 (Literal (vars,(false,atom)), def), Definition (rel,axiom))
    end;

(* ------------------------------------------------------------------------- *)
(* Definitional conjunctive normal form.                                     *)
(* ------------------------------------------------------------------------- *)

(* State threaded through successive calls to addCnf. *)
datatype cnf =
    (* No contradiction derived so far; carries the simplifier state. *)
    ConsistentCnf of simplify
    (* False has been derived; addCnf then produces no further clauses. *)
  | InconsistentCnf;

(* Starting state: consistent, with simplifyEmpty (defined earlier in the *)
(* file) as the simplifier state.                                         *)
val initialCnf = ConsistentCnf simplifyEmpty;

local
  (* Result returned once False has been derived: a single clause made of *)
  (* the empty literal set paired with th, and the InconsistentCnf state. *)
  fun def_cnf_inconsistent th =
      let
        val cls = [(LiteralSet.empty,th)]
      in
        (cls,InconsistentCnf)
      end;

  (* Fold step: converts fm with toClause, pairs the clause with the      *)
  (* theorem Thm (fm,inf), and conses the pair onto acc.                  *)
  fun def_cnf_clause inf (fm,acc) =
      let
        val cl = toClause fm
        val th = Thm (fm,inf)
      in
        (cl,th) :: acc
      end
(*MetisDebug
      handle Error err =>
        (Print.trace pp "Normalize.addCnf.def_cnf_clause: fm" fm;
         raise Bug ("Normalize.addCnf.def_cnf_clause: " ^ err));
*)

  (* Worklist loop. cls is the list of clauses produced so far (newest    *)
  (* first), simp the current simplifier state, ths the theorems still to *)
  (* be processed. Each theorem is simplified with simp before it is      *)
  (* dispatched; an empty worklist yields cls with ConsistentCnf simp.    *)
  fun def_cnf cls simp ths =
      case ths of
        [] => (cls, ConsistentCnf simp)
      | th :: ths => def_cnf_formula cls simp (simplify simp th) ths

  (* Process one theorem according to the top-level shape of its formula. *)
  and def_cnf_formula cls simp (th as Thm (fm,_)) ths =
      case fm of
        (* True contributes nothing. *)
        True => def_cnf cls simp ths
        (* False ends the run; the remaining worklist is dropped. *)
      | False => def_cnf_inconsistent th
        (* Each conjunct is pushed onto the worklist as its own theorem. *)
      | And (_,_,s) =>
        let
          fun add (f,z) = Thm (f, Conjunct th) :: z
        in
          def_cnf cls simp (Set.foldr add ths s)
        end
        (* Existential: replaced by the result of skolemize and processed *)
        (* directly, without another simplify pass.                       *)
      | Exists (fv,_,n,f) =>
        let
          val th = Thm (skolemize fv n f, Skolemize th)
        in
          def_cnf_formula cls simp th ths
        end
        (* Universal: the quantifier is dropped and the body processed *)
        (* directly.                                                   *)
      | Forall (_,_,_,f) =>
        let
          val th = Thm (f, Specialize th)
        in
          def_cnf_formula cls simp th ths
        end
      | _ =>
        case minimumDefinition fm of
          (* A subformula def was selected for naming: the current       *)
          (* theorem goes back on the worklist and the new definition    *)
          (* theorem is processed first.                                 *)
          (* NOTE(review): presumably the definition is registered in    *)
          (* simp on that pass, so that re-simplifying th replaces def   *)
          (* by the new atom - confirm against minimumDefinition.        *)
          SOME def =>
          let
            val ths = th :: ths
            val th = newDefinition def
          in
            def_cnf_formula cls simp th ths
          end
          (* No definition needed: th is added to the simplifier, its    *)
          (* formula is converted with basicCnf, and the resulting       *)
          (* clauses are added to cls.                                   *)
        | NONE =>
          let
            val simp = simplifyAdd simp th

            val fm = basicCnf fm

            val inf = Clausify th
          in
            case fm of
              True => def_cnf cls simp ths
            | False => def_cnf_inconsistent (Thm (fm,inf))
              (* One clause per conjunct, each justified by Conjunct of *)
              (* the clausified theorem.                                *)
            | And (_,_,s) =>
              let
                val inf = Conjunct (Thm (fm,inf))
                val cls = Set.foldl (def_cnf_clause inf) cls s
              in
                def_cnf cls simp ths
              end
              (* Anything else is treated as a single clause. *)
            | fm => def_cnf (def_cnf_clause inf (fm,cls)) simp ths
          end;
in
  (* addCnf th cnf returns the clauses generated for th (each paired with *)
  (* the theorem it came from) together with the updated state. In the    *)
  (* InconsistentCnf state no clauses are produced and the state is       *)
  (* returned unchanged.                                                  *)
  fun addCnf th cnf =
      case cnf of
        ConsistentCnf simp => def_cnf [] simp [th]
      | InconsistentCnf => ([],cnf);
end;

local
  (* Fold step: pass one theorem to addCnf and place the clauses it yields *)
  (* in front of those gathered so far, carrying the new state forward.    *)
  fun step (th,(gathered,state)) =
      case addCnf th state of
        (added,state') => (added @ gathered, state');
in
  (* Run addCnf over ths from left to right, starting from initialCnf, and *)
  (* return the reversed accumulation of (clause,theorem) pairs. The final *)
  (* state is discarded.                                                   *)
  fun proveCnf ths =
      case List.foldl step ([],initialCnf) ths of
        (cls,_) => List.rev cls;
end;

(* Clauses obtained by running proveCnf on the single theorem mkAxiom fm; *)
(* only the clause component of each result pair is kept.                 *)
fun cnf fm = List.map fst (proveCnf [mkAxiom fm]);

end
